{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "animations_playground.ipynb",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pRIb9Xqjnmxn"
   },
   "source": [
    "<a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/yuval-alaluf/restyle-encoder/blob/main/notebooks/animations_playground.ipynb\"><img align=\"left\" alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a><a href=\"https://colab.research.google.com/github/yuval-alaluf/restyle-encoder/blob/main/notebooks/animations_playground.ipynb\"><img align=\"left\" title=\"Open in Colab\" src=\"https://colab.research.google.com/assets/colab-badge.svg\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "VcNK15ganhUH"
   },
   "source": [
    "import os\n",
    "os.chdir('/content')\n",
    "CODE_DIR = 'restyle-encoder'"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "reo0WG2JnnJv"
   },
   "source": [
    "!git clone https://github.com/yuval-alaluf/restyle-encoder.git $CODE_DIR"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "n2NZkkHWntAf"
   },
   "source": [
    "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
    "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n",
    "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ybuAB7yDntS5"
   },
   "source": [
    "os.chdir(f'./{CODE_DIR}')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "YHH5jMEVn2qz"
   },
   "source": [
    "from argparse import Namespace\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "import pprint\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import imageio\n",
    "import matplotlib\n",
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    "\n",
    "sys.path.append(\".\")\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from utils.common import tensor2im\n",
    "from utils.inference_utils import run_on_batch\n",
    "from models.psp import pSp\n",
    "from models.e4e import e4e\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ba4ovOESo1Su"
   },
   "source": [
    "## Step 1: Select Experiment Type\n",
    "Select which experiment you wish to generate animations on:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "_wO0FrBNo07X"
   },
   "source": [
    "#@title Select which experiment you wish to perform inference on: { run: \"auto\" }\n",
    "experiment_type = 'ffhq_encode' #@param ['ffhq_encode', 'cars_encode', 'church_encode', 'horse_encode', 'afhq_wild_encode', 'toonify']"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4etDz82xkTJz"
   },
   "source": [
    "## Step 2: Prepare to Download Pretrained Models \n",
    "As part of this repository, we provide pretrained models for each of the above experiments. Here, we'll create the download command needed for downloading the desired model.\n",
    "\n",
    "Note: in this notebook, we'll be using ReStyle applied over pSp for all domains except for the horses domain where we'll be using e4e. This is done since e4e is generally able to generate more realistic reconstructions on this domain. "
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "KSnjlBZOkTJ0"
   },
   "source": [
    "def get_download_model_command(file_id, file_name):\n",
    "    \"\"\"Build the shell command that downloads a pretrained model from Google Drive.\n",
    "\n",
    "    The model is saved under <parent of the current working directory>/CODE_DIR/pretrained_models,\n",
    "    and that directory is created if it does not exist yet.\n",
    "\n",
    "    Args:\n",
    "        file_id: Google Drive file id, substituted into the download URL.\n",
    "        file_name: file name the download is written to inside the save directory.\n",
    "\n",
    "    Returns:\n",
    "        The full command as a string; it already starts with `wget`.\n",
    "    \"\"\"\n",
    "    save_path = os.path.join(os.path.dirname(os.getcwd()), CODE_DIR, \"pretrained_models\")\n",
    "    # exist_ok=True replaces the separate os.path.exists() check\n",
    "    os.makedirs(save_path, exist_ok=True)\n",
    "    # the inner wget/sed pipeline extracts the `confirm` value from a first request and passes it to the outer one\n",
    "    command = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n",
    "    return command"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "m4sjldFMkTJ5"
   },
   "source": [
    "# Google Drive file id and local file name of the pretrained checkpoint for each experiment\n",
    "MODEL_PATHS = {\n",
    "    \"ffhq_encode\": {\n",
    "        \"id\": \"1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE\",\n",
    "        \"name\": \"restyle_psp_ffhq_encode.pt\",\n",
    "    },\n",
    "    \"cars_encode\": {\n",
    "        \"id\": \"1zJHqHRQ8NOnVohVVCGbeYMMr6PDhRpPR\",\n",
    "        \"name\": \"restyle_psp_cars_encode.pt\",\n",
    "    },\n",
    "    \"church_encode\": {\n",
    "        \"id\": \"1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD\",\n",
    "        \"name\": \"restyle_psp_church_encode.pt\",\n",
    "    },\n",
    "    \"horse_encode\": {\n",
    "        \"id\": \"19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY\",\n",
    "        \"name\": \"restyle_e4e_horse_encode.pt\",\n",
    "    },\n",
    "    \"afhq_wild_encode\": {\n",
    "        \"id\": \"1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7\",\n",
    "        \"name\": \"restyle_psp_afhq_wild_encode.pt\",\n",
    "    },\n",
    "    \"toonify\": {\n",
    "        \"id\": \"1GtudVDig59d4HJ_8bGEniz5huaTSGO_0\",\n",
    "        \"name\": \"restyle_psp_toonify.pt\",\n",
    "    },\n",
    "}\n",
    "\n",
    "# look up the entry for the experiment selected above and build its download command\n",
    "path = MODEL_PATHS[experiment_type]\n",
    "download_command = get_download_model_command(path[\"id\"], path[\"name\"])"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9Tozsg81kTKA"
   },
   "source": [
    "## Step 3: Define Inference Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XIhyc7RqkTKB"
   },
   "source": [
    "Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the image to perform inference on.  \n",
    "While we provide default values to run this script, feel free to change as needed."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "2kE5y1-skTKC"
   },
   "source": [
    "def _make_transform(size):\n",
    "    \"\"\"Return the preprocessing pipeline shared by all experiments.\n",
    "\n",
    "    The image is resized to `size` (passed straight to transforms.Resize), converted to a tensor\n",
    "    and normalized with mean 0.5 and std 0.5 on each of the three channels.\n",
    "    \"\"\"\n",
    "    return transforms.Compose([\n",
    "        transforms.Resize(size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "\n",
    "# per-experiment checkpoint path, default input image and preprocessing;\n",
    "# every experiment resizes to (256, 256) except cars, which uses (192, 256)\n",
    "EXPERIMENT_DATA_ARGS = {\n",
    "    \"ffhq_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_ffhq_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/face_img.jpg\",\n",
    "        \"transform\": _make_transform((256, 256))\n",
    "    },\n",
    "    \"cars_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_cars_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/car_img.jpg\",\n",
    "        \"transform\": _make_transform((192, 256))\n",
    "    },\n",
    "    \"church_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_church_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/church_img.jpg\",\n",
    "        \"transform\": _make_transform((256, 256))\n",
    "    },\n",
    "    \"horse_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_e4e_horse_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/horse_img.jpg\",\n",
    "        \"transform\": _make_transform((256, 256))\n",
    "    },\n",
    "    \"afhq_wild_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_afhq_wild_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/afhq_wild_img.jpg\",\n",
    "        \"transform\": _make_transform((256, 256))\n",
    "    },\n",
    "    \"toonify\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_toonify.pt\",\n",
    "        \"image_path\": \"notebooks/images/toonify_img.jpg\",\n",
    "        \"transform\": _make_transform((256, 256))\n",
    "    },\n",
    "}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "IzUHoD9ukTKG"
   },
   "source": [
    "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DjrgneZZnd49"
   },
   "source": [
    "To reduce the number of requests to fetch the model, we'll check whether the model was previously downloaded and saved before downloading it again.  \n",
    "We'll download the model for the selected experiment and save it to the `pretrained_models` folder of the repository.\n",
    "\n",
    "We also need to verify that the model was downloaded correctly. All of our models should weigh approximately 800MB - 1GB.  \n",
    "Note that if the file weighs only several KBs, you most likely encountered a \"quota exceeded\" error from Google Drive. In that case, you should try downloading the model again after a few hours."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "jQ31J_m7kTJ8"
   },
   "source": [
    "# download only if the checkpoint is missing or too small to be a real model (< 1,000,000 bytes)\n",
    "if not os.path.exists(EXPERIMENT_ARGS['model_path']) or os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000:\n",
    "    print(f'Downloading ReStyle model for {experiment_type}...')\n",
    "    # download_command is already a complete `wget ...` invocation; prefixing a second `wget`\n",
    "    # would make wget treat the word `wget` as an additional URL to fetch\n",
    "    os.system(download_command)\n",
    "    # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model;\n",
    "    # also check that the file exists so a failed download raises the error below rather than FileNotFoundError\n",
    "    if not os.path.exists(EXPERIMENT_ARGS['model_path']) or os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000:\n",
    "        raise ValueError(\"Pretrained model was unable to be downloaded correctly!\")\n",
    "    print('Done.')\n",
    "else:\n",
    "    print(f'ReStyle model for {experiment_type} already exists!')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TAWrUehTkTKJ"
   },
   "source": [
    "## Step 4: Load Pretrained Model\n",
    "We assume that you have downloaded all relevant models and placed them in the directory defined by the above dictionary."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "1t-AOhP1kTKJ"
   },
   "source": [
    "model_path = EXPERIMENT_ARGS['model_path']\n",
    "ckpt = torch.load(model_path, map_location='cpu')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "2UBwJ3dJkTKM"
   },
   "source": [
    "opts = ckpt['opts']\n",
    "pprint.pprint(opts)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "EMKhWoFKkTKS"
   },
   "source": [
    "# update the training options\n",
    "opts['checkpoint_path'] = model_path"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "6hccfNizkTKW"
   },
   "source": [
    "# convert the checkpoint options to a Namespace only once: on a re-run of this cell `opts` is\n",
    "# already a Namespace and Namespace(**opts) would raise a TypeError\n",
    "if not isinstance(opts, Namespace):\n",
    "    opts = Namespace(**opts)\n",
    "\n",
    "# the horses experiment uses the e4e class; every other experiment uses pSp\n",
    "if experiment_type == 'horse_encode':\n",
    "    net = e4e(opts)\n",
    "else:\n",
    "    net = pSp(opts)\n",
    "\n",
    "net.eval()\n",
    "net.cuda()\n",
    "print('Model successfully loaded!')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mdQ0fXM7ppSr"
   },
   "source": [
    "## Define Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "CJbNCfLaplKu"
   },
   "source": [
    "def generate_mp4(out_name, images, kwargs):\n",
    "  \"\"\"Write `images` as consecutive frames of the video file `out_name + '.mp4'`.\n",
    "\n",
    "  Args:\n",
    "      out_name: output path without the '.mp4' extension.\n",
    "      images: iterable of frames; each one is passed to the writer's append_data.\n",
    "      kwargs: dict of keyword arguments forwarded to imageio.get_writer.\n",
    "  \"\"\"\n",
    "  writer = imageio.get_writer(out_name + '.mp4', **kwargs)\n",
    "  try:\n",
    "      for image in images:\n",
    "          writer.append_data(image)\n",
    "  finally:\n",
    "      # close the writer even when appending a frame fails, so it is never left open\n",
    "      writer.close()\n",
    "\n",
    "\n",
    "def run_on_batch_to_vecs(inputs, net, opts):\n",
    "  \"\"\"Run `run_on_batch` on `inputs` and return the last entry of its first result.\n",
    "\n",
    "  Note: this sets `n_iters_per_batch` and `resize_outputs` on the `opts` object that is passed in,\n",
    "  and it reads the notebook-level `avg_image`, which must be defined before the first call.\n",
    "  \"\"\"\n",
    "  opts.n_iters_per_batch = 5\n",
    "  opts.resize_outputs = False\n",
    "  with torch.no_grad():\n",
    "      batch = inputs.to(\"cuda\").float()\n",
    "      _, latents = run_on_batch(batch, net, opts, avg_image)\n",
    "  # presumably the code from the final iteration for the first image in the batch -- confirm against run_on_batch\n",
    "  return latents[0][-1]\n",
    "\n",
    "\n",
    "def get_result_from_vecs(vectors_a, vectors_b, alpha):\n",
    "  \"\"\"Decode the linear blend of two lists of vectors with the notebook-level `net`.\n",
    "\n",
    "  For every index i the blended vector is vectors_b[i] * alpha + vectors_a[i] * (1 - alpha),\n",
    "  so alpha=0 reproduces vectors_a and alpha=1 reproduces vectors_b.\n",
    "  Returns a list holding the first element of each `net` output.\n",
    "  \"\"\"\n",
    "  results = []\n",
    "  with torch.no_grad():\n",
    "      for i, vec_a in enumerate(vectors_a):\n",
    "          blended = vectors_b[i] * alpha + vec_a * (1 - alpha)\n",
    "          # the blend is converted with torch.from_numpy and given a leading batch dimension\n",
    "          batch = torch.from_numpy(blended).cuda().unsqueeze(0)\n",
    "          decoded = net(batch, randomize_noise=False,\n",
    "                        input_code=True, input_is_full=True, resize=False)\n",
    "          results.append(decoded[0])\n",
    "  return results\n",
    "\n",
    "def show_mp4(filename, width):\n",
    "    \"\"\"Display the video `filename + '.mp4'` inline in the notebook.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the video without the '.mp4' extension.\n",
    "        width: value used for the width attribute of the <video> element.\n",
    "    \"\"\"\n",
    "    # read through a context manager so the file handle is closed right away\n",
    "    with open(filename + '.mp4', 'rb') as video_file:\n",
    "        mp4 = video_file.read()\n",
    "    # embed the whole file as a base64 data URL so no separate file has to be served\n",
    "    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
    "    display(HTML(\"\"\"\n",
    "    <video width=\"%d\" controls autoplay loop>\n",
    "        <source src=\"%s\" type=\"video/mp4\">\n",
    "    </video>\n",
    "    \"\"\" % (width, data_url)))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "cgGQ7grCp9JI"
   },
   "source": [
    "# get the image corresponding to the latent average\n",
    "avg_image = net(net.latent_avg.unsqueeze(0),\n",
    "                input_code=True,\n",
    "                randomize_noise=False,\n",
    "                return_latents=False,\n",
    "                average_code=True)[0]\n",
    "avg_image = avg_image.to('cuda').float().detach()\n",
    "if opts.dataset_type == \"cars_encode\":\n",
    "    avg_image = avg_image[:, 32:224, :]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "p6MdqDhDqDhb"
   },
   "source": [
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "\n",
    "img_transforms = EXPERIMENT_ARGS['transform']\n",
    "root_dir = \"notebooks/images/\"\n",
    "image_names = ['', '', '', '', '']\n",
    "image_paths = [os.path.join(root_dir, image) + '.jpg' for image in image_names]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pvsWbbOytp4a"
   },
   "source": [
    "### Align the images if needed. You can skip this step if working on non-face images or if your images are pre-aligned."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "MYvyGVgTt2eA"
   },
   "source": [
    "ALIGN_IMAGES = False"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "57mIkFsbuMaa"
   },
   "source": [
    "def run_alignment(image_path):\n",
    "    \"\"\"Align the face in the image at `image_path` and return the aligned image.\n",
    "\n",
    "    The dlib landmark model file is downloaded and unpacked into the current working directory\n",
    "    the first time it is needed.\n",
    "    \"\"\"\n",
    "    # kept local to the function (presumably so dlib is only required when alignment is used)\n",
    "    import dlib\n",
    "    from scripts.align_faces_parallel import align_face\n",
    "    predictor_file = \"shape_predictor_68_face_landmarks.dat\"\n",
    "    if not os.path.exists(predictor_file):\n",
    "        print('Downloading files for aligning face image...')\n",
    "        os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')\n",
    "        # -d decompresses, -k keeps the downloaded archive\n",
    "        os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2')\n",
    "        print('Done.')\n",
    "    predictor = dlib.shape_predictor(predictor_file)\n",
    "    aligned_image = align_face(\n",
    "        filepath=image_path,\n",
    "        predictor=predictor,\n",
    "        output_size=256,\n",
    "        transform_size=256,\n",
    "    )\n",
    "    print(\"Aligned image has shape: {}\".format(aligned_image.size))\n",
    "    return aligned_image"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "_Bpzbr_Etoow"
   },
   "source": [
    "# only align images if working on faces and if specified\n",
    "if ALIGN_IMAGES and experiment_type in [\"ffhq_encode\", \"toonify\"]: \n",
    "  aligned_image_paths = []\n",
    "  for image_name, image_path in zip(image_names, image_paths): \n",
    "    print(f'Aligning {image_name}...')\n",
    "    aligned_image = run_alignment(image_path)\n",
    "    aligned_path = os.path.join(root_dir, f'{image_name}_aligned.jpg')\n",
    "    # save the aligned image\n",
    "    aligned_image.save(aligned_path)\n",
    "    aligned_image_paths.append(aligned_path)\n",
    "  # use the save aligned images as our input image paths\n",
    "  image_paths = aligned_image_paths"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nA8pC10Yry6U"
   },
   "source": [
    "Run inference on each image and save the inverted latent codes."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "CFQyaD1Qqe3d"
   },
   "source": [
    "in_images, all_vecs = [], []\n",
    "\n",
    "# size (width, height) the original images are resized to for display next to the outputs\n",
    "resize_amount = (512, 384) if experiment_type == \"cars_encode\" else (opts.output_size, opts.output_size)\n",
    "\n",
    "for image_path in image_paths:\n",
    "    print(f'Working on {os.path.basename(image_path)}...')\n",
    "    original_image = Image.open(image_path).convert(\"RGB\")\n",
    "    input_image = img_transforms(original_image)\n",
    "    with torch.no_grad():\n",
    "        # add a batch dimension: the helper is called with a batch of one image\n",
    "        result_vec = run_on_batch_to_vecs(input_image.unsqueeze(0), net, opts)\n",
    "    # each entry is wrapped in a list because get_result_from_vecs iterates over lists of vectors\n",
    "    all_vecs.append([result_vec])\n",
    "    in_images.append(original_image.resize(resize_amount))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wV7yaSXar88q"
   },
   "source": [
    "Interpolate!"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "TsA3azDuqtWd"
   },
   "source": [
    "n_transition = 25\n",
    "if experiment_type == \"cars_encode\":\n",
    "  SIZE = 384\n",
    "else:\n",
    "  SIZE = opts.output_size\n",
    "\n",
    "# close the loop by returning to the first image; new lists are built instead of appending to\n",
    "# image_paths / all_vecs / in_images in place, so re-running this cell does not add another copy each time\n",
    "loop_vecs = all_vecs + [all_vecs[0]]\n",
    "loop_images = in_images + [in_images[0]]\n",
    "\n",
    "# same schedule for every transition: 5 frames at alpha=0, n_transition blended frames, 5 frames at alpha=1\n",
    "# (the former `i == 0` special case could never run because the loop starts at 1)\n",
    "alpha_vals = [0] * 5 + np.linspace(0, 1, n_transition).tolist() + [1] * 5\n",
    "\n",
    "images = []\n",
    "for i in range(1, len(loop_images)):\n",
    "    # the two source images do not change with alpha, so convert them once per transition\n",
    "    image_a = np.array(loop_images[i - 1])\n",
    "    image_b = np.array(loop_images[i])\n",
    "\n",
    "    for alpha in tqdm(alpha_vals):\n",
    "        # left panel: the bottom rows of image_b fill the top of the frame, the top rows of image_a fill the rest\n",
    "        image_joint = np.zeros_like(image_a)\n",
    "        up_to_row = int((SIZE - 1) * alpha)\n",
    "        if up_to_row > 0:\n",
    "            image_joint[:(up_to_row + 1), :, :] = image_b[((SIZE - 1) - up_to_row):, :, :]\n",
    "        if up_to_row < (SIZE - 1):\n",
    "            image_joint[up_to_row:, :, :] = image_a[:(SIZE - up_to_row), :, :]\n",
    "\n",
    "        # right panel: the image decoded from the blended vectors\n",
    "        result_image = get_result_from_vecs(loop_vecs[i - 1], loop_vecs[i], alpha)[0]\n",
    "        if experiment_type == \"cars_encode\":\n",
    "          # keep indices 64..447 along dim 1 (384 entries, equal to SIZE for cars)\n",
    "          result_image = result_image[:, 64:448, :]\n",
    "\n",
    "        output_im = tensor2im(result_image)\n",
    "        res = np.concatenate([image_joint, np.array(output_im)], axis=1)\n",
    "        images.append(res)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "YqvG0oJtsUWt"
   },
   "source": [
    "# write all frames to notebooks/animations/<experiment_type>_gif.mp4 and play the result inline\n",
    "save_path = \"notebooks/animations\"\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "kwargs = dict(fps=15)\n",
    "\n",
    "# the helpers append the '.mp4' extension themselves\n",
    "gif_path = os.path.join(save_path, f\"{experiment_type}_gif\")\n",
    "generate_mp4(gif_path, images, kwargs)\n",
    "show_mp4(gif_path, width=opts.output_size)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "fhN1Jw9C02Ah"
   },
   "source": [
    ""
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}